import os
import numpy as np
from time import time

import torch
import torch.nn as nn
import torch.nn.functional as F

from utility.parser import parse_args
from utility.norm import build_sim, build_knn_normalized_graph
from torch_geometric.utils import dropout_adj
args = parse_args()
import torch
import torch.nn.functional as F
from torch.nn import Parameter
# from torch_scatter import scatter_mean
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax


class SAGEConv(MessagePassing):
    """Parameter-free aggregation layer with externally supplied edge weights.

    Every message is the sender's feature row multiplied by the weight given
    for that edge at call time. Messages are reduced with ``aggr`` and the
    reduced result is returned unchanged.

    Args:
        in_channels: Recorded on the instance; only read by ``__repr__``.
        out_channels: Recorded on the instance; only read by ``__repr__``.
        normalize: Accepted but not used by this class.
        bias: Accepted but not used by this class.
        aggr: Aggregation scheme forwarded to ``MessagePassing``.
        **kwargs: Extra keyword arguments forwarded to ``MessagePassing``.
    """

    def __init__(self, in_channels, out_channels, normalize=True, bias=True, aggr='mean', **kwargs):
        super().__init__(aggr=aggr, **kwargs)
        self.in_channels, self.out_channels = in_channels, out_channels

    def forward(self, x, edge_index, weight_vector, size=None):
        """Propagate ``x`` along ``edge_index`` scaling messages by ``weight_vector``.

        ``weight_vector`` is multiplied with the gathered sender features, so
        it must broadcast against them (presumably one row per edge -- confirm
        against the caller).
        """
        # message() has no direct access to forward()'s arguments, so the
        # weights are parked on the instance for the duration of the call.
        self.weight_vector = weight_vector
        aggregated = self.propagate(edge_index, size=size, x=x)
        return aggregated

    def message(self, x_j):
        """Return the sender features scaled by the stored edge weights."""
        return torch.mul(x_j, self.weight_vector)

    def update(self, aggr_out):
        """Identity update: the aggregated messages are the layer output."""
        return aggr_out

    def __repr__(self):
        return f'{type(self).__name__}({self.in_channels}, {self.out_channels})'


class GATConv(MessagePassing):
    """Parameter-free attention layer using dot-product scores.

    For every edge the score is the dot product of the receiving and sending
    node features. Scores are softmax-normalised per receiving node and used
    to weight the sender features, which are then summed (``aggr='add'``).
    The attention coefficients of the most recent call are kept in
    ``self.alpha``.

    Args:
        in_channels: Recorded on the instance; not otherwise used here.
        out_channels: Recorded on the instance; not otherwise used here.
        self_loops: When true, self-loops are re-added after being stripped.
    """

    def __init__(self, in_channels, out_channels, self_loops=False):
        super().__init__(aggr='add')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.self_loops = self_loops

    def forward(self, x, edge_index, size=None):
        """Run one round of attention-weighted propagation over ``edge_index``."""
        # Existing self-loops are always dropped first; they only come back
        # when the layer was built with self_loops=True.
        stripped, _ = remove_self_loops(edge_index)
        if not self.self_loops:
            return self.propagate(stripped, size=size, x=x)

        looped, _ = add_self_loops(stripped, num_nodes=x.size(0))
        return self.propagate(looped, size=size, x=x)

    def message(self, x, edge_index, edge_index_i, edge_index_j, x_i, x_j, size_i):
        """Compute attention coefficients and the weighted messages.

        Direction of message passing (notes translated from the original
        author's review comments):

        * PyG's default is ``flow="source_to_target"``: the first row of
          ``edge_index`` holds the source (neighbouring) nodes and the second
          row the target (aggregating) nodes, i.e. ``x_j = x[edge_index[0]]``
          and ``x_i = x[edge_index[1]]``, and information travels
          ``x_j -> x_i``. See also
          https://blog.csdn.net/LittleSeedling/article/details/125020621 .
        * Consequently, when edges are given as (user, item) this layer
          carries user information onto the item nodes. A caller that wants
          users to aggregate item information (equation 3 of the paper) has
          to supply the edges as (item, user).
        * No self-loop is added by default, so a receiving node's own
          feature is not part of the aggregated output.
        * NOTE(review): the score is a plain dot product; it equals cosine
          similarity only if the caller passes row-normalised features --
          confirm at the call site.

        The softmax (equation 2) is taken over all edges that share the same
        receiving node, so the coefficients of one receiver sum to one.
        """
        # Raw score per edge; no nonlinearity is applied before the softmax.
        scores = (x_i * x_j).sum(dim=-1)

        # Equation 2: normalise over the incoming edges of each receiver.
        # Stored on the instance so callers can read the edge weights back.
        self.alpha = softmax(scores, edge_index_i, num_nodes=size_i)

        per_edge = self.alpha.view(-1, 1)
        return x_j * per_edge

    def update(self, aggr_out):
        """Identity update: the aggregated messages are the layer output."""
        return aggr_out


class EGCN(torch.nn.Module):
    """Two-layer graph convolution over learnable ID embeddings.

    A single embedding table covers users and items
    (``num_user + num_item`` rows). Both layers are ``SAGEConv`` instances
    that scale each message by an externally supplied edge weight.

    Args:
        num_user: Number of user nodes.
        num_item: Number of item nodes.
        dim_E: Width of the ID embedding.
        aggr_mode: Aggregation scheme handed to both ``SAGEConv`` layers.
        has_act: Apply leaky ReLU after each layer when true.
        has_norm: Row-normalise the embedding table before propagating.
    """

    def __init__(self, num_user, num_item, dim_E, aggr_mode, has_act, has_norm):
        super().__init__()
        self.num_user, self.num_item = num_user, num_item
        self.dim_E = dim_E
        self.aggr_mode = aggr_mode
        self.has_act, self.has_norm = has_act, has_norm

        # Allocate with torch.rand, then overwrite in place with Xavier-normal
        # values before wrapping the tensor as a parameter.
        table = torch.rand((num_user + num_item, dim_E))
        nn.init.xavier_normal_(table)
        self.id_embedding = nn.Parameter(table)

        self.conv_embed_1 = SAGEConv(dim_E, dim_E, aggr=aggr_mode)
        self.conv_embed_2 = SAGEConv(dim_E, dim_E, aggr=aggr_mode)

    def forward(self, edge_index, weight_vector):
        """Return the sum of the input embeddings and both layer outputs.

        ``edge_index`` is made bidirectional here by appending its row-swapped
        copy, so ``weight_vector`` has to line up with the doubled edge list
        (original edges first, reversed edges second).
        """
        reversed_edges = edge_index[[1, 0]]
        both_ways = torch.cat((edge_index, reversed_edges), dim=1)

        base = F.normalize(self.id_embedding) if self.has_norm else self.id_embedding

        layer_outputs = []
        hidden = base
        for conv in (self.conv_embed_1, self.conv_embed_2):
            hidden = conv(hidden, both_ways, weight_vector)
            if self.has_act:
                hidden = F.leaky_relu_(hidden)
            layer_outputs.append(hidden)

        # Residual-style sum: input + first layer + second layer.
        total = base
        for layer_out in layer_outputs:
            total = total + layer_out
        return total


class CGCN(torch.nn.Module):
    """Content graph convolution for a single modality.

    Item content features are projected to ``dim_C`` by a linear layer plus
    leaky ReLU (equation 1). A learnable per-user preference table is then
    refined for ``num_routing`` rounds by attending over each user's items
    (equations 2-3). A final bidirectional attention pass produces the node
    representations and one attention weight per directed edge.

    Args:
        features: Item content features; ``features.size(1)`` is the input
            width of the projection.
        num_user: Number of user nodes.
        num_item: Number of item nodes.
        dim_C: Width of the projected content space.
        aggr_mode: Stored on the instance; the attention layer itself always
            sums messages.
        num_routing: Number of preference-refinement rounds.
        has_act: Apply leaky ReLU to the final convolution output when true.
        has_norm: Row-normalise preferences and features when true.
    """

    def __init__(self, features, num_user, num_item, dim_C, aggr_mode, num_routing, has_act, has_norm):
        super(CGCN, self).__init__()
        self.num_user = num_user
        self.num_item = num_item
        self.aggr_mode = aggr_mode
        # Number of routing (preference refinement) iterations.
        self.num_routing = num_routing
        self.has_act = has_act
        self.has_norm = has_norm
        self.dim_C = dim_C
        self.preference = nn.Parameter(
            nn.init.xavier_normal_(torch.rand((num_user, dim_C))))
        self.conv_embed_1 = GATConv(self.dim_C, self.dim_C)
        self.dim_feat = features.size(1)
        self.features = features
        self.MLP = nn.Linear(self.dim_feat, self.dim_C)

    def forward(self, edge_index, mask):
        """Compute modality representations and per-edge attention weights.

        Args:
            edge_index: ``[2, E]`` edge list. Assumed to hold user ids in row
                0 and item ids (offset by ``num_user``) in row 1, as the
                original author's notes describe -- confirm against the
                data loader.
            mask: When truthy, every item uses the mean content feature and
                every user the mean preference, which removes this
                modality's per-node information.

        Returns:
            A tuple ``(rep, alpha)``. ``rep`` stacks user rows followed by
            item rows. ``alpha`` is the attention coefficient of the final
            pass as a column vector, ordered as the original edges followed
            by their reversed copies.
        """
        # Equation 1: project content features into the shared space.
        # NOTE(review): leaky_relu's default negative slope is 0.01, so
        # negative activations are strongly damped.
        if mask:
            mean_feature = self.features.mean(dim=0).tile(self.num_item, 1)
            features = F.leaky_relu(self.MLP(mean_feature))
            preference = self.preference.mean(dim=0).tile(self.num_user, 1)
        else:
            features = F.leaky_relu(self.MLP(self.features))
            preference = self.preference

        if self.has_norm:
            preference = F.normalize(preference)
            features = F.normalize(features)

        # Routing must deliver item content to the users. PyG's default flow
        # is source_to_target (row 0 sends, row 1 receives), so the
        # (user, item) edges are flipped to (item, user) here. Passing them
        # unflipped sends messages to item rows only, which leaves
        # x_hat_1[:num_user] all zeros and turns equation 3 into a no-op.
        item_to_user = edge_index[[1, 0]]
        for _ in range(self.num_routing):
            x = torch.cat((preference, features), dim=0)
            x_hat_1 = self.conv_embed_1(x, item_to_user)
            preference = preference + x_hat_1[:self.num_user]  # equation 3

            if self.has_norm:
                preference = F.normalize(preference)

        x = torch.cat((preference, features), dim=0)
        # Final pass runs in both directions: original edges, then reversed.
        edge_index = torch.cat((edge_index, edge_index[[1, 0]]), dim=1)

        x_hat_1 = self.conv_embed_1(x, edge_index)

        if self.has_act:
            x_hat_1 = F.leaky_relu_(x_hat_1)

        # The attention layer adds no self-loops, so the input is added back
        # to keep each node's own representation.
        return x + x_hat_1, self.conv_embed_1.alpha.view(-1, 1)


class GRCN(torch.nn.Module):
    """Graph-refined recommender combining ID and multimodal content graphs.

    Two content convolutions (visual and textual) each produce node
    representations and a weight per directed edge. The weights are fused
    according to ``weight_mode`` and drive the ID-embedding convolution. The
    final representation is assembled according to ``fusion_mode``.

    Args:
        n_users: Number of users.
        n_items: Number of items.
        embedding_dim: Width used for both the ID and the content spaces.
        weight_size: Accepted but not used by this class.
        dropout_list: Accepted but not used by this class.
        image_feats: Visual item features (converted with ``torch.Tensor``).
        text_feats: Textual item features (converted with ``torch.Tensor``).
        adj: Accepted but not used by this class.
        edge_index: Edge list of shape ``[E, 2]``; transposed to ``[2, E]``
            for PyG. Presumably (user, item) pairs in one direction only --
            confirm against the data loader.
        aggr_mode: Aggregation scheme passed to the sub-networks.
        weight_mode: ``'mean'``, ``'max'`` or ``'confid'``.
        fusion_mode: ``'concat'``, ``'id'`` or ``'mean'``.
        num_routing: Routing rounds of each content convolution.
        dropout: Edge dropout probability.
        has_act: Enable leaky ReLU inside the sub-networks.
        has_norm: Enable row normalisation inside the sub-networks.
        pruning: Clamp fused edge weights at zero with ReLU when true.

    Raises:
        ValueError: If ``weight_mode`` or ``fusion_mode`` is not supported.
    """

    _WEIGHT_MODES = ('mean', 'max', 'confid')
    _FUSION_MODES = ('concat', 'id', 'mean')

    def __init__(self, n_users, n_items, embedding_dim, weight_size, dropout_list, 
                image_feats, text_feats, adj ,edge_index,
                aggr_mode='add', weight_mode='confid', fusion_mode='concat',
                num_routing=3, dropout=0,
                has_act=False, has_norm=True,
                pruning=True):
        super(GRCN, self).__init__()
        # Fail fast instead of surfacing an opaque error in forward().
        if weight_mode not in self._WEIGHT_MODES:
            raise ValueError(
                'unsupported weight_mode {!r}; expected one of {}'.format(
                    weight_mode, self._WEIGHT_MODES))
        if fusion_mode not in self._FUSION_MODES:
            raise ValueError(
                'unsupported fusion_mode {!r}; expected one of {}'.format(
                    fusion_mode, self._FUSION_MODES))

        self.n_users = n_users
        self.n_items = n_items
        self.weight_mode = weight_mode
        self.fusion_mode = fusion_mode
        # Not read anywhere in this class; kept for external users.
        self.weight = torch.tensor([[1.0], [-1.0]]).cuda()
        self.dropout = dropout
        dim_E, dim_C = embedding_dim, embedding_dim

        # [E, 2] -> [2, E] as PyG expects. The reverse direction is not
        # added here; the sub-networks append it themselves where needed.
        self.edge_index = torch.tensor(edge_index).t().contiguous().cuda()

        # The content convolutions yield per-edge weights which are then
        # used to convolve the ID embeddings.
        self.id_gcn = EGCN(n_users, n_items, dim_E,
                           aggr_mode, has_act, has_norm)
        self.v_feat = torch.Tensor(image_feats).cuda()
        self.t_feat = torch.Tensor(text_feats).cuda()
        self.pruning = pruning

        num_model = 2
        self.v_gcn = CGCN(self.v_feat, n_users, n_items,
                            dim_C, aggr_mode, num_routing, has_act, has_norm)

        self.t_gcn = CGCN(self.t_feat, n_users, n_items, 
                            dim_C, aggr_mode, num_routing, has_act, has_norm)

        # One confidence value per node and modality (equation 7).
        self.model_specific_conf = nn.Parameter(
            nn.init.xavier_normal_(torch.rand((n_users+n_items, num_model))))

    def forward(self, training=1):
        """Compute user and item representations.

        Args:
            training: Integer masking code, unrelated to ``nn.Module``'s
                train/eval flag. ``2`` masks the visual modality, ``3`` the
                textual modality, ``4`` both; any other value masks nothing.

        Returns:
            ``(user_rep, item_rep)``: the first ``n_users`` rows and the
            remaining rows of the fused representation.

        Raises:
            ValueError: If ``weight_mode`` or ``fusion_mode`` was changed to
                an unsupported value after construction.
        """
        num_modal = 2
        # Edge dropout must follow the module's train/eval state; without
        # training= it would keep dropping edges during evaluation.
        edge_index, _ = dropout_adj(
            self.edge_index, p=self.dropout, training=self.training)
        visual_mask = training == 2 or training == 4
        textual_mask = training == 3 or training == 4

        # Each content convolution returns node representations plus one
        # weight per directed edge (i->j and j->i are weighted separately).
        v_rep, weight_v = self.v_gcn(edge_index, visual_mask)
        t_rep, weight_t = self.t_gcn(edge_index, textual_mask)
        content_rep = torch.cat((v_rep, t_rep), dim=1)

        if self.weight_mode == 'mean':
            weight = (weight_v + weight_t) / num_modal
        elif self.weight_mode in ('max', 'confid'):
            weight = torch.cat((weight_v, weight_t), dim=1)
            if self.weight_mode == 'confid':
                # Equation 7: scale each modality's weight by a learned
                # per-node confidence before taking the maximum.
                confidence = torch.cat(
                    (self.model_specific_conf[edge_index[0]],
                     self.model_specific_conf[edge_index[1]]), dim=0)
                weight = weight * confidence
            weight, _ = torch.max(weight, dim=1)
            weight = weight.view(-1, 1)
        else:
            raise ValueError(
                'unsupported weight_mode {!r}; expected one of {}'.format(
                    self.weight_mode, self._WEIGHT_MODES))

        if self.pruning:
            # Negative weights are cut to zero, which removes those edges'
            # contribution to the ID convolution.
            weight = torch.relu(weight)

        id_rep = self.id_gcn(edge_index, weight)

        if self.fusion_mode == 'concat':
            representation = torch.cat((id_rep, content_rep), dim=1)
        elif self.fusion_mode == 'id':
            representation = id_rep
        elif self.fusion_mode == 'mean':
            representation = (id_rep+v_rep+t_rep)/3
        else:
            raise ValueError(
                'unsupported fusion_mode {!r}; expected one of {}'.format(
                    self.fusion_mode, self._FUSION_MODES))

        user_rep = representation[:self.n_users]
        item_rep = representation[self.n_users:]
        return user_rep, item_rep
